import os
import h5py
import numpy as np
import argparse
from IPython import embed
from pytorch3d.loss import chamfer_distance
import torch
from pytorch3d.ops import cubify, sample_points_from_meshes
from scipy.optimize import linear_sum_assignment

def diversity(lines):
    """Print and return distinct-1, distinct-2 and the distinct-sentence ratio.

    Each line is split with ``str.split(' ')`` (so repeated spaces produce
    empty tokens, as before). distinct-1 is the number of unique tokens over
    the total token count; distinct-2 is the same for adjacent token pairs.
    The third value is the fraction of lines that are unique.

    Args:
        lines: sized sequence of caption strings (iterated and passed to
            ``set``/``len``).

    Returns:
        Tuple ``(d1, d2, distinct_ratio)``. A ratio whose denominator is zero
        (no lines at all, or no line with two or more tokens for d2) is
        reported as 0.0 instead of raising ZeroDivisionError.
    """
    uni_set = set()
    uni_num = 0
    bi_set = set()
    bi_num = 0
    for line in lines:
        tokens = line.split(' ')
        uni_set.update(tokens)
        uni_num += len(tokens)
        bigrams = [first + ' ' + second for first, second in zip(tokens, tokens[1:])]
        bi_set.update(bigrams)
        bi_num += len(bigrams)
    # Guard every denominator: single-token captions give bi_num == 0.
    d1 = len(uni_set) / uni_num if uni_num else 0.0
    d2 = len(bi_set) / bi_num if bi_num else 0.0
    distinct = len(set(lines)) / float(len(lines)) if len(lines) else 0.0
    print('diversity:', d1, d2)
    print('distinct sentences', distinct)
    return d1, d2, distinct


def _str2bool(value):
    """Convert a command-line flag value to a bool.

    ``type=bool`` would call ``bool("False")``, which is True for every
    non-empty string, so ``--emd False`` used to enable EMD. This parser
    accepts case-insensitive true/false, yes/no, y/n, t/f, on/off and 1/0.
    An empty string stays False, as it was with ``bool``.

    Raises:
        argparse.ArgumentTypeError: for any other spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


parser = argparse.ArgumentParser()

parser.add_argument('--save_name', type=str, required=True,
                    help='run name; captions are loaded from '
                         './outputs/dalle_outputs/test<save_name>_shapecaptioning')
parser.add_argument('--category', type=str, required=True,
                    help='category name used to build the result and shape file names')
# nargs='?' + const=True: a bare "--emd" means True, and "--emd True/False" still works.
parser.add_argument('--emd', type=_str2bool, nargs='?', const=True, default=False,
                    help="also compute the earth mover's distance (true/false)")
parser.add_argument('--ori', type=_str2bool, nargs='?', const=True, default=False,
                    help='read predicted shapes from ./shape2prog/output/<category> '
                         'instead of the vqprogram outputs of this run (true/false)')

args = parser.parse_args()

# Load the saved captioning results of this run. The dict provides
# 'our_text' (generated captions) and 'gt_text' (reference captions).
# NOTE: torch.load unpickles the file; only point it at trusted outputs.
run_tag = 'test' + args.save_name
load_dir = os.path.join('./outputs/dalle_outputs', run_tag + '_shapecaptioning')
results_file = run_tag + '_' + args.category + '_shapecaptioning.pth'
load_dict = torch.load(os.path.join(load_dir, results_file))
our_text, gt_text = load_dict['our_text'], load_dict['gt_text']

from collections import defaultdict

# Reference captions per key: keep element [0] of every entry of gt_text[key].
gt_text_flat = defaultdict(list)
for shape_key in gt_text:
    captions = [ref[0] for ref in gt_text[shape_key]]
    if captions:
        gt_text_flat[shape_key].extend(captions)

# One scoring entry per generated caption: entry n holds that caption as a
# single-element candidate list, together with the reference list of the key
# it was generated for (the same list object, shared across that key's entries).
our_text_final = {}
gt_text_final = {}
count = 0
for shape_key in our_text:
    for caption in our_text[shape_key]:
        our_text_final[count], gt_text_final[count] = [caption], gt_text_flat[shape_key]
        count += 1
import sys

# pycocoevalcap is imported from a local checkout rather than an installed
# package. Set PYCOCOEVALCAP_DIR to use a different checkout; the default is
# the original author's path.
sys.path.insert(0, os.environ.get('PYCOCOEVALCAP_DIR',
                                  '/home/tiangel/pytorch-GraphWriter/pycocoevalcap'))
from bleu.bleu import Bleu
from cider.cider import Cider

bleu = Bleu(4)
cider = Cider()
print(args.save_name, args.category)
# compute_score returns a pair (per pycocoevalcap's API). Only the first
# element, the overall score, is printed; the per-sample scores are discarded.
bleu_scores = bleu.compute_score(gt_text_final, our_text_final)[0]
print('bleu_scores:', bleu_scores)
cider_scores = cider.compute_score(gt_text_final, our_text_final)[0]
print('cider_scores:', cider_scores)

# cider_scores = cider.compute_score(gt_text_final, our_text_final)
# from pytorch3d.io import save_ply
# #np.mean(cider_scores[1].reshape([-1, 48]), 1)
# #np.argmax(np.mean(cider_scores[1].reshape([-1, 48]), 1))
# # data1 = h5py.File('/home/tiangel/datasets/text2shape_pc_v1_val.h5', 'r')
# # data1 = h5py.File('/home/tiangel/datasets/abo_pc_v2_val.h5', 'r')
# data1 = h5py.File('/home/tiangel/datasets/shapeglot_pc_v2_val.h5', 'r')
# pcs = np.array(data1['data'])
# # rank = np.argsort(np.mean(cider_scores[1].reshape([-1, 48]), 1))
# rank = np.argsort(np.mean(cider_scores[1].reshape([-1, 8]), 1))
# for i in range(20):
#     save_ply("./final_captioning/sp%d.ply"%i, torch.Tensor(pcs[rank[-i]]))
# # save_ply("./final_captioning/abo486.ply", torch.Tensor(pcs[486]))
# embed()

# Diversity of the generated captions (the single candidate of each entry)...
our_lines = [candidates[0] for candidates in our_text_final.values()]
diversity(our_lines)
# ...and of every reference caption, read directly from gt_text.
gt_lines = [ref[0] for shape_key in gt_text for ref in gt_text[shape_key]]
diversity(gt_lines)
# Open an interactive IPython shell for inspection, then stop. Because of
# exit(), none of the shape-evaluation code below this point is ever reached.
embed()
exit()

# Predicted shapes come either from the original shape2prog outputs (--ori)
# or from this run's vqprogram predictions. np.array copies each dataset into
# memory, so the files can be closed as soon as they have been read.
if args.ori:
    ours_path = os.path.join('./shape2prog/output/', args.category, 'shapes.h5')
    with h5py.File(ours_path, 'r') as ours_shape_h5:
        ours_shapes = np.array(ours_shape_h5['data'])
else:
    ours_path = os.path.join('./shape2prog/vqprogram_outputs/', 'test' + args.save_name,
                             'pred', args.category + '.h5')
    with h5py.File(ours_path, 'r') as ours_shape_h5:
        ours_shapes = np.array(ours_shape_h5['shape'])

# Ground-truth test shapes for the same category.
target_path = os.path.join('./shape2prog/data/', args.category + '_testing.h5')
with h5py.File(target_path, 'r') as target_shape_h5:
    target_shapes = np.array(target_shape_h5['data'])

def _voxels_to_point_cloud(voxel_grids, count):
    """Cubify the first ``count`` voxel grids and sample a point cloud from each.

    Each grid is wrapped in a batch of one, converted to a mesh with
    ``cubify(..., 0.5)`` and sampled with ``sample_points_from_meshes`` using
    its default sample count. The per-shape results are stacked with
    ``np.vstack`` into one array with one leading entry per shape.
    """
    clouds = []
    for idx in range(count):
        mesh = cubify(torch.Tensor(voxel_grids[idx]).unsqueeze(0), 0.5)
        clouds.append(sample_points_from_meshes(mesh))
    return np.vstack(clouds)


# Both clouds cover the same number of shapes, so they line up index by index
# for the Chamfer/EMD comparison. NOTE(review): this assumes target_shapes uses
# the same order and voxel layout as ours_shapes and has at least as many entries.
num_shapes = ours_shapes.shape[0]
ours_pc = _voxels_to_point_cloud(ours_shapes, num_shapes)
# Sampled from target_shapes (not ours_shapes), so the metrics compare each
# prediction against its ground truth rather than against itself.
target_pc = _voxels_to_point_cloud(target_shapes, num_shapes)

cd_dis = chamfer_distance(torch.Tensor(ours_pc).cuda(), torch.Tensor(target_pc).cuda())[0]
print('cd_dis:', cd_dis)


def _matched_point_distance(q1, q2):
    """Return the mean Euclidean distance between q1 and q2 under an optimal matching.

    The cost of pairing q1[a] with q2[b] is their squared Euclidean distance;
    scipy's linear_sum_assignment picks the minimum-cost one-to-one matching,
    and the mean (unsquared) distance of the matched pairs is returned. The
    point count is taken from the inputs rather than being fixed. A single
    (len(q1), len(q2), 3) difference array is built, so memory grows
    quadratically with the number of points.
    """
    diff = q1[:, None, :] - q2[None, :, :]
    cost = np.sum(diff * diff, axis=-1)
    row_ind, col_ind = linear_sum_assignment(cost)
    matched = q1[row_ind] - q2[col_ind]
    return np.mean(np.sqrt(np.sum(matched * matched, axis=-1)))


if args.emd:
    emd_dis = []
    for i, (q1, q2) in enumerate(zip(ours_pc, target_pc)):
        print('emd', i)
        emd_dis.append(_matched_point_distance(q1, q2))
    # Printed inside the branch: emd_dis only exists when --emd is set.
    print('emd_dis:', np.mean(np.array(emd_dis)))

embed()
